{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7d105473-c0c6-4de6-acfb-ccf3054fd1a0",
   "metadata": {},
   "source": [
    "# Tool Pattern"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39e9c48d-cac5-48a1-a6ff-e44b618b92c9",
   "metadata": {},
   "source": [
    "source : https://github.com/neural-maze/agentic_patterns\n",
    "\n",
    "---\n",
    "\n",
    "As you may already know, the information stored in LLM weights is (usually) 𝐧𝐨𝐭 𝐞𝐧𝐨𝐮𝐠𝐡 to give accurate and insightful answers to our questions.\n",
    " \n",
    "That's why we need to provide the LLM with ways to access the outside world. 🌍 \n",
    "\n",
    "In practice, you can build tools for whatever you want (at the end of the day they are just functions the LLM can use), from a tool that let's you access Wikipedia, another to analyse the content of YouTube videos or calculate difficult integrals using Wolfram Alpha. \n",
    "\n",
    "The second pattern we are going to implement is the **tool pattern**. \n",
    "\n",
    "In this notebook, you'll learn how **tools** actually work. This is the **second lesson** of the \"Agentic Patterns from Scratch\" series. Take a look at the first lesson if you haven't!\n",
    "\n",
    "* [First Lesson: The Reflection Pattern](https://github.com/neural-maze/agentic_patterns/blob/main/notebooks/reflection_pattern.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6eb2bab-9a5b-4c92-b23a-18f757d44c06",
   "metadata": {},
   "source": [
    "## A simple function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "148df24a-4ac5-4d3d-9860-8ff0e7ed7c90",
   "metadata": {},
   "source": [
    "Take a look at this function 👇"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c851271-9b5a-4b48-a0e0-bf889cfb303b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "def get_current_weather(location: str, unit: str):\n",
    "\t\"\"\"\n",
    "\tGet the current weather in a given location\n",
    "\n",
    "\tlocation (str): The city and state, e.g. Madrid, Barcelona\n",
    "\tunit (str): The unit. It can take two values; \"celsius\", \"fahrenheit\"\n",
    "\t\"\"\"\n",
    "\tif location == \"Madrid\":\n",
    "\t\treturn json.dumps({\"temperature\": 25, \"unit\": unit})\n",
    "\n",
    "\telse:\n",
    "\t\treturn json.dumps({\"temperature\": 58, \"unit\": unit})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de31cb35-847f-458f-b7d7-603acf5a714a",
   "metadata": {},
   "source": [
    "Very simple, right? You provide a `location` and a `unit` and it returns the temperature."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f52e61e-be31-4e6f-9f4f-eeb7082ad827",
   "metadata": {},
   "outputs": [],
   "source": [
    "get_current_weather(location=\"Madrid\", unit=\"celsius\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9d63a34-8a93-4551-a34a-a0e85c95aa6a",
   "metadata": {},
   "source": [
    "But the question is:\n",
    "\n",
    "**How can you make this function available to an LLM?**\n",
    "\n",
    "An LLM is a type of NLP system, so it expects text as input. But how can we transform this function into text?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56a4f2f8-9fc2-4e3d-87cd-bdfca15e5ddc",
   "metadata": {},
   "source": [
    "## A System Prompt that works"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93bed242-75ca-4ab7-a159-114a9e1e7e67",
   "metadata": {},
   "source": [
    "For the LLM to be aware of this function, we need to provide some relevant information about it in the context. **I'm referring to the function name, attributes, description, etc.** Take a look at the following System Prompt."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ad89df0-d233-41cc-b002-19963e7740a1",
   "metadata": {},
   "source": [
    "```xml\n",
    "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. \n",
    "You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug \n",
    "into functions. Pay special attention to the properties 'types'. You should use those types as in a Python dict.\n",
    "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n",
    "\n",
    "<tool_call>\n",
    "{\"name\": <function-name>,\"arguments\": <args-dict>}\n",
    "</tool_call>\n",
    "\n",
    "Here are the available tools:\n",
    "\n",
    "<tools> {\n",
    "    \"name\": \"get_current_weather\",\n",
    "    \"description\": \"Get the current weather in a given location location (str): The city and state, e.g. Madrid, Barcelona unit (str): The unit. It can take two values; 'celsius', 'fahrenheit'\",\n",
    "    \"parameters\": {\n",
    "        \"properties\": {\n",
    "            \"location\": {\n",
    "                \"type\": \"string\"\n",
    "            },\n",
    "            \"unit\": {\n",
    "                \"type\": \"string\"\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}\n",
    "</tools>\n",
    "```\n",
    "\n",
    "\n",
    "As you can see, the LLM enforces the LLM to behave as a `function calling AI model` who, given a list of function signatures inside the <tools></tools> XML tags\n",
    "will select which one to use. When the model decides a function to use, it will return a json like the following, representing a function call:\n",
    "\n",
    "```xml\n",
    "<tool_call>\n",
    "{\"name\": <function-name>,\"arguments\": <args-dict>}\n",
    "</tool_call>\n",
    "```\n"
   ]
  },
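  {
   "cell_type": "markdown",
   "id": "c3a1e9f2-6b47-4d1a-8e25-0f7b9d3c5a11",
   "metadata": {},
   "source": [
    "Nothing in this prompt is magic: the `<tools>` section is just a JSON description of the function, serialised into text. As a minimal sketch, assuming we describe the function by hand in a dict, the same kind of prompt can be built programmatically. The `weather_signature` dict and the `build_tool_system_prompt` helper below are hypothetical names used only for illustration, and the instructions text is shortened."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a1e9f2-6b47-4d1a-8e25-0f7b9d3c5a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Hand-written description of get_current_weather (hypothetical example data)\n",
    "weather_signature = {\n",
    "    'name': 'get_current_weather',\n",
    "    'description': 'Get the current weather in a given location',\n",
    "    'parameters': {\n",
    "        'properties': {\n",
    "            'location': {'type': 'str'},\n",
    "            'unit': {'type': 'str'},\n",
    "        }\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "def build_tool_system_prompt(signatures: list) -> str:\n",
    "    \"\"\"Serialise each signature as JSON and wrap them in <tools></tools> tags.\"\"\"\n",
    "    tools_block = '\\n'.join(json.dumps(sig, indent=4) for sig in signatures)\n",
    "    return (\n",
    "        'You are a function calling AI model. You are provided with function '\n",
    "        'signatures within <tools></tools> XML tags.\\n\\n'\n",
    "        'Here are the available tools:\\n\\n'\n",
    "        f'<tools>\\n{tools_block}\\n</tools>\\n'\n",
    "    )\n",
    "\n",
    "\n",
    "prompt = build_tool_system_prompt([weather_signature])\n",
    "print(prompt)"
   ]
  },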
  {
   "cell_type": "markdown",
   "id": "0d2d8322-afc0-4469-90aa-23019bc929e7",
   "metadata": {},
   "source": [
    "Let's see how it works in practise! 👇"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "692b5c16-77f3-4de0-b2b5-16bfc5812b7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "from groq import Groq\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# Remember to load the environment variables. You should have the Groq API Key in there :)\n",
    "load_dotenv()\n",
    "\n",
    "MODEL = \"llama3-groq-70b-8192-tool-use-preview\"\n",
    "GROQ_CLIENT = Groq()\n",
    "\n",
    "# Define the System Prompt as a constant\n",
    "TOOL_SYSTEM_PROMPT = \"\"\"\n",
    "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. \n",
    "You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug \n",
    "into functions. Pay special attention to the properties 'types'. You should use those types as in a Python dict.\n",
    "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n",
    "\n",
    "<tool_call>\n",
    "{\"name\": <function-name>,\"arguments\": <args-dict>}\n",
    "</tool_call>\n",
    "\n",
    "Here are the available tools:\n",
    "\n",
    "<tools> {\n",
    "    \"name\": \"get_current_weather\",\n",
    "    \"description\": \"Get the current weather in a given location location (str): The city and state, e.g. Madrid, Barcelona unit (str): The unit. It can take two values; 'celsius', 'fahrenheit'\",\n",
    "    \"parameters\": {\n",
    "        \"properties\": {\n",
    "            \"location\": {\n",
    "                \"type\": \"str\"\n",
    "            },\n",
    "            \"unit\": {\n",
    "                \"type\": \"str\"\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}\n",
    "</tools>\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0da45c0-0b4b-4153-83c7-eed1c312dcec",
   "metadata": {},
   "source": [
    "Let's ask a very simple question: `\"What's the current temperature in Madrid, in Celsius?\"`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e00b09e8-55d3-4a59-a9cf-29329af78d9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_chat_history = [\n",
    "    {\n",
    "        \"role\": \"system\",\n",
    "        \"content\": TOOL_SYSTEM_PROMPT\n",
    "    }\n",
    "]\n",
    "agent_chat_history = []\n",
    "\n",
    "user_msg = {\n",
    "    \"role\": \"user\",\n",
    "    \"content\": \"What's the current temperature in Madrid, in Celsius?\"\n",
    "}\n",
    "\n",
    "tool_chat_history.append(user_msg)\n",
    "agent_chat_history.append(user_msg)\n",
    "\n",
    "output = GROQ_CLIENT.chat.completions.create(\n",
    "    messages=tool_chat_history,\n",
    "    model=MODEL\n",
    ").choices[0].message.content\n",
    "\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c26cf72-0b60-464e-9f83-af371a93b3d5",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "**That's an improvement!** We may not have the *proper* answer but, with this information, we can obtain it! How? Well, we just need to:\n",
    "\n",
    "1. Parse the LLM output. By this I mean deleting the XML tags\n",
    "2. Load the output as a proper Python dict\n",
    "\n",
    "The function below does exactly this.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4366ae38-055a-45ec-937b-dfec7eaad00b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_tool_call_str(tool_call_str: str):\n",
    "    pattern = r'</?tool_call>'\n",
    "    clean_tags = re.sub(pattern, '', tool_call_str)\n",
    "    \n",
    "    try:\n",
    "        tool_call_json = json.loads(clean_tags)\n",
    "        return tool_call_json\n",
    "    except json.JSONDecodeError:\n",
    "        return clean_tags\n",
    "    except Exception as e:\n",
    "        print(f\"Unexpected error: {e}\")\n",
    "        return \"There was some error parsing the Tool's output\""
   ]
  },
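  {
   "cell_type": "markdown",
   "id": "c3a1e9f2-6b47-4d1a-8e25-0f7b9d3c5a21",
   "metadata": {},
   "source": [
    "`parse_tool_call_str` assumes the reply contains *only* the tool call. If a model adds some prose around it (for example `Sure! <tool_call>...</tool_call>`), deleting the tags still leaves text that `json.loads` cannot decode. A slightly more defensive sketch, under that assumption, extracts just the text between the first pair of tags before decoding it. `extract_first_tool_call` is a hypothetical helper, not part of the original code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a1e9f2-6b47-4d1a-8e25-0f7b9d3c5a22",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "from typing import Optional\n",
    "\n",
    "\n",
    "def extract_first_tool_call(text: str) -> Optional[dict]:\n",
    "    \"\"\"Return the first <tool_call> payload as a dict, or None if it is missing or not valid JSON.\"\"\"\n",
    "    # Non-greedy match; re.DOTALL lets '.' also match the newlines around the JSON\n",
    "    match = re.search(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL)\n",
    "    if match is None:\n",
    "        return None\n",
    "    try:\n",
    "        return json.loads(match.group(1))\n",
    "    except json.JSONDecodeError:\n",
    "        return None\n",
    "\n",
    "\n",
    "sample_reply = 'Sure!\\n<tool_call>\\n{\"name\": \"get_current_weather\", \"arguments\": {\"location\": \"Madrid\", \"unit\": \"celsius\"}}\\n</tool_call>'\n",
    "print(extract_first_tool_call(sample_reply))"
   ]
  },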
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5890ba4-3f2f-4dc8-9a62-dff0079f07bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "parsed_output = parse_tool_call_str(output)\n",
    "parsed_output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "944b0373-f647-423a-bf00-914ffb03dcd7",
   "metadata": {},
   "source": [
    "We can simply run the function now, by passing the arguments like this 👇"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "169f06bb-836d-4270-bd66-abc2aadc0757",
   "metadata": {},
   "outputs": [],
   "source": [
    "result = get_current_weather(**parsed_output[\"arguments\"])"
   ]
  },
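  {
   "cell_type": "markdown",
   "id": "c3a1e9f2-6b47-4d1a-8e25-0f7b9d3c5a31",
   "metadata": {},
   "source": [
    "Here we called `get_current_weather` directly because we already knew which function the model would pick. With several tools, a common approach is to look the function up by the `name` field of the tool call. A minimal sketch, using a stub weather function and a hypothetical `AVAILABLE_TOOLS` registry (the `ToolAgent` class later in this notebook keeps a similar `tools_dict`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a1e9f2-6b47-4d1a-8e25-0f7b9d3c5a32",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "def get_current_weather_stub(location: str, unit: str) -> str:\n",
    "    # Stand-in for the real tool: returns a fixed reading for any location\n",
    "    return json.dumps({'temperature': 25, 'unit': unit})\n",
    "\n",
    "\n",
    "# Hypothetical registry mapping the names the model sees to Python callables\n",
    "AVAILABLE_TOOLS = {'get_current_weather': get_current_weather_stub}\n",
    "\n",
    "\n",
    "def dispatch_tool_call(tool_call: dict) -> str:\n",
    "    \"\"\"Look up the requested tool by name and call it with the model's arguments.\"\"\"\n",
    "    name = tool_call['name']\n",
    "    if name not in AVAILABLE_TOOLS:\n",
    "        raise KeyError(f'Unknown tool requested by the model: {name}')\n",
    "    return AVAILABLE_TOOLS[name](**tool_call['arguments'])\n",
    "\n",
    "\n",
    "observation = dispatch_tool_call(\n",
    "    {'name': 'get_current_weather', 'arguments': {'location': 'Madrid', 'unit': 'celsius'}}\n",
    ")\n",
    "print(observation)  # {\"temperature\": 25, \"unit\": \"celsius\"}"
   ]
  },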
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecdfbbc5-7cdf-4c21-8b75-055446658675",
   "metadata": {},
   "outputs": [],
   "source": [
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "272a337d-c193-4316-bed5-bc1ee4ccaae5",
   "metadata": {},
   "source": [
    "**That's it!** A temperature of 25 degrees Celsius. \n",
    "\n",
    "As you can see, we're dealing with a string, so we can simply add the parsed_output to the `chat_history` so that the LLM knows the information it has to return to the user. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fb0fc08-dad9-42cd-a2a9-674b8191d06b",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_chat_history.append({\n",
    "    \"role\": \"user\",\n",
    "    \"content\": f\"Observation: {result}\"\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b610fb1f-24af-4cc1-b485-fa0c5bfca846",
   "metadata": {},
   "outputs": [],
   "source": [
    "GROQ_CLIENT.chat.completions.create(\n",
    "    messages=agent_chat_history,\n",
    "    model=MODEL\n",
    ").choices[0].message.content"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72fa386e-edef-4e3f-903d-a2fc7008e5c3",
   "metadata": {},
   "source": [
    "## Implementing everything the good way"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4217eb34-efac-4a05-bb23-ae780126c0ad",
   "metadata": {},
   "source": [
    "To recap, we have a way for the LLM to generate `tool_calls` that we can use later to *properly* run the functions. But, as you may imagine, there are some pieces missing:\n",
    "\n",
    "1. We need to automatically transform any function into a description like we saw in the initial system prompt.\n",
    "2. We need a way to tell the agent that this function is a tool\n",
    "\n",
    "Let's do it!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df20db23-3c1a-4744-88b8-8d47d7875f18",
   "metadata": {},
   "source": [
    "### The `tool` decorator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c538804-a381-4552-94eb-c04720e897df",
   "metadata": {},
   "source": [
    "We are going to use the `tool` decorator to transform any Python function into a tool. You can see the implementation [here](https://github.com/neural-maze/agentic_patterns/blob/main/src/agentic_patterns/tool_pattern/tool.py). To test it out, let's make a more complex tool than before. For example, a tool that interacts with [Hacker News](https://news.ycombinator.com/), getting the current top stories. \n",
    "\n",
    "> Reminder: To automatically generate the function signature for the tool, we need a way to infer the arguments types. For this reason, we need to create the typing annotations. "
   ]
  },
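  {
   "cell_type": "markdown",
   "id": "c3a1e9f2-6b47-4d1a-8e25-0f7b9d3c5a41",
   "metadata": {},
   "source": [
    "To see why the annotations matter, here is what Python actually stores for an annotated and an unannotated function (both are toy examples). `get_fn_signature` in the next cell builds the `properties` dict from `fn.__annotations__`, so a function without annotations would end up with an empty `properties` dict."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a1e9f2-6b47-4d1a-8e25-0f7b9d3c5a42",
   "metadata": {},
   "outputs": [],
   "source": [
    "def annotated(city: str, days: int) -> str:\n",
    "    return f'{city}: {days} days'\n",
    "\n",
    "\n",
    "def unannotated(city, days):\n",
    "    return f'{city}: {days} days'\n",
    "\n",
    "\n",
    "# Type hints are stored as a plain dict on the function object\n",
    "print(annotated.__annotations__)\n",
    "print(unannotated.__annotations__)\n",
    "\n",
    "# Same comprehension as in get_fn_signature: map each argument to its type name, skipping 'return'\n",
    "properties = {k: {'type': v.__name__} for k, v in annotated.__annotations__.items() if k != 'return'}\n",
    "print(properties)"
   ]
  },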
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "222d9bf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from typing import Callable\n",
    "\n",
    "\n",
    "def get_fn_signature(fn: Callable) -> dict:\n",
    "    \"\"\"\n",
    "    Generates the signature for a given function.\n",
    "\n",
    "    Args:\n",
    "        fn (Callable): The function whose signature needs to be extracted.\n",
    "\n",
    "    Returns:\n",
    "        dict: A dictionary containing the function's name, description,\n",
    "              and parameter types.\n",
    "    \"\"\"\n",
    "    fn_signature: dict = {\n",
    "        \"name\": fn.__name__,\n",
    "        \"description\": fn.__doc__,\n",
    "        \"parameters\": {\"properties\": {}},\n",
    "    }\n",
    "    schema = {\n",
    "        k: {\"type\": v.__name__} for k, v in fn.__annotations__.items() if k != \"return\"\n",
    "    }\n",
    "    fn_signature[\"parameters\"][\"properties\"] = schema\n",
    "    return fn_signature\n",
    "\n",
    "\n",
    "def validate_arguments(tool_call: dict, tool_signature: dict) -> dict:\n",
    "    \"\"\"\n",
    "    Validates and converts arguments in the input dictionary to match the expected types.\n",
    "\n",
    "    Args:\n",
    "        tool_call (dict): A dictionary containing the arguments passed to the tool.\n",
    "        tool_signature (dict): The expected function signature and parameter types.\n",
    "\n",
    "    Returns:\n",
    "        dict: The tool call dictionary with the arguments converted to the correct types if necessary.\n",
    "    \"\"\"\n",
    "    properties = tool_signature[\"parameters\"][\"properties\"]\n",
    "\n",
    "    # TODO: This is overly simplified but enough for simple Tools.\n",
    "    type_mapping = {\n",
    "        \"int\": int,\n",
    "        \"str\": str,\n",
    "        \"bool\": bool,\n",
    "        \"float\": float,\n",
    "    }\n",
    "\n",
    "    for arg_name, arg_value in tool_call[\"arguments\"].items():\n",
    "        expected_type = properties[arg_name].get(\"type\")\n",
    "\n",
    "        if not isinstance(arg_value, type_mapping[expected_type]):\n",
    "            tool_call[\"arguments\"][arg_name] = type_mapping[expected_type](arg_value)\n",
    "\n",
    "    return tool_call\n",
    "\n",
    "\n",
    "class Tool:\n",
    "    \"\"\"\n",
    "    A class representing a tool that wraps a callable and its signature.\n",
    "\n",
    "    Attributes:\n",
    "        name (str): The name of the tool (function).\n",
    "        fn (Callable): The function that the tool represents.\n",
    "        fn_signature (str): JSON string representation of the function's signature.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, name: str, fn: Callable, fn_signature: str):\n",
    "        self.name = name\n",
    "        self.fn = fn\n",
    "        self.fn_signature = fn_signature\n",
    "\n",
    "    def __str__(self):\n",
    "        return self.fn_signature\n",
    "\n",
    "    def run(self, **kwargs):\n",
    "        \"\"\"\n",
    "        Executes the tool (function) with provided arguments.\n",
    "\n",
    "        Args:\n",
    "            **kwargs: Keyword arguments passed to the function.\n",
    "\n",
    "        Returns:\n",
    "            The result of the function call.\n",
    "        \"\"\"\n",
    "        return self.fn(**kwargs)\n",
    "\n",
    "\n",
    "def tool(fn: Callable):\n",
    "    \"\"\"\n",
    "    A decorator that wraps a function into a Tool object.\n",
    "\n",
    "    Args:\n",
    "        fn (Callable): The function to be wrapped.\n",
    "\n",
    "    Returns:\n",
    "        Tool: A Tool object containing the function, its name, and its signature.\n",
    "    \"\"\"\n",
    "\n",
    "    def wrapper():\n",
    "        fn_signature = get_fn_signature(fn)\n",
    "        return Tool(\n",
    "            name=fn_signature.get(\"name\"), fn=fn, fn_signature=json.dumps(fn_signature)\n",
    "        )\n",
    "\n",
    "    return wrapper()\n"
   ]
  },
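  {
   "cell_type": "markdown",
   "id": "c3a1e9f2-6b47-4d1a-8e25-0f7b9d3c5a51",
   "metadata": {},
   "source": [
    "One caveat with `validate_arguments`, which its TODO already hints at: for booleans it calls `bool(arg_value)`, and `bool` treats any non-empty string as `True`, so a model that sends `\"false\"` as a string would flip the meaning. A minimal sketch of a stricter conversion, assuming you want to accept common string spellings of booleans. `coerce_argument` is a hypothetical helper, not part of the library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a1e9f2-6b47-4d1a-8e25-0f7b9d3c5a52",
   "metadata": {},
   "outputs": [],
   "source": [
    "TRUE_STRINGS = {'true', '1', 'yes'}\n",
    "FALSE_STRINGS = {'false', '0', 'no'}\n",
    "\n",
    "\n",
    "def coerce_argument(value, type_name: str):\n",
    "    \"\"\"Convert a model-supplied value to the named simple type, parsing booleans explicitly.\"\"\"\n",
    "    if type_name == 'bool':\n",
    "        if isinstance(value, bool):\n",
    "            return value\n",
    "        text = str(value).strip().lower()\n",
    "        if text in TRUE_STRINGS:\n",
    "            return True\n",
    "        if text in FALSE_STRINGS:\n",
    "            return False\n",
    "        raise ValueError(f'Cannot interpret {value!r} as a bool')\n",
    "    # The plain constructors are fine for the other simple types\n",
    "    converters = {'int': int, 'float': float, 'str': str}\n",
    "    return converters[type_name](value)\n",
    "\n",
    "\n",
    "print(bool('false'))                     # True: the pitfall\n",
    "print(coerce_argument('false', 'bool'))  # False\n",
    "print(coerce_argument('3', 'int'))       # 3"
   ]
  },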
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dd3f304",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from dataclasses import dataclass\n",
    "import time\n",
    "\n",
    "from colorama import Fore\n",
    "from colorama import Style\n",
    "\n",
    "def completions_create(client, messages: list, model: str) -> str:\n",
    "    \"\"\"\n",
    "    Sends a request to the client's `completions.create` method to interact with the language model.\n",
    "\n",
    "    Args:\n",
    "        client (Groq): The Groq client object\n",
    "        messages (list[dict]): A list of message objects containing chat history for the model.\n",
    "        model (str): The model to use for generating tool calls and responses.\n",
    "\n",
    "    Returns:\n",
    "        str: The content of the model's response.\n",
    "    \"\"\"\n",
    "    response = client.chat.completions.create(messages=messages, model=model)\n",
    "    return str(response.choices[0].message.content)\n",
    "\n",
    "\n",
    "def build_prompt_structure(prompt: str, role: str, tag: str = \"\") -> dict:\n",
    "    \"\"\"\n",
    "    Builds a structured prompt that includes the role and content.\n",
    "\n",
    "    Args:\n",
    "        prompt (str): The actual content of the prompt.\n",
    "        role (str): The role of the speaker (e.g., user, assistant).\n",
    "\n",
    "    Returns:\n",
    "        dict: A dictionary representing the structured prompt.\n",
    "    \"\"\"\n",
    "    if tag:\n",
    "        prompt = f\"<{tag}>{prompt}</{tag}>\"\n",
    "    return {\"role\": role, \"content\": prompt}\n",
    "\n",
    "\n",
    "def update_chat_history(history: list, msg: str, role: str):\n",
    "    \"\"\"\n",
    "    Updates the chat history by appending the latest response.\n",
    "\n",
    "    Args:\n",
    "        history (list): The list representing the current chat history.\n",
    "        msg (str): The message to append.\n",
    "        role (str): The role type (e.g. 'user', 'assistant', 'system')\n",
    "    \"\"\"\n",
    "    history.append(build_prompt_structure(prompt=msg, role=role))\n",
    "\n",
    "\n",
    "class ChatHistory(list):\n",
    "    def __init__(self, messages: list | None = None, total_length: int = -1):\n",
    "        \"\"\"Initialise the queue with a fixed total length.\n",
    "\n",
    "        Args:\n",
    "            messages (list | None): A list of initial messages\n",
    "            total_length (int): The maximum number of messages the chat history can hold.\n",
    "        \"\"\"\n",
    "        if messages is None:\n",
    "            messages = []\n",
    "\n",
    "        super().__init__(messages)\n",
    "        self.total_length = total_length\n",
    "\n",
    "    def append(self, msg: str):\n",
    "        \"\"\"Add a message to the queue.\n",
    "\n",
    "        Args:\n",
    "            msg (str): The message to be added to the queue\n",
    "        \"\"\"\n",
    "        if len(self) == self.total_length:\n",
    "            self.pop(0)\n",
    "        super().append(msg)\n",
    "\n",
    "\n",
    "class FixedFirstChatHistory(ChatHistory):\n",
    "    def __init__(self, messages: list | None = None, total_length: int = -1):\n",
    "        \"\"\"Initialise the queue with a fixed total length.\n",
    "\n",
    "        Args:\n",
    "            messages (list | None): A list of initial messages\n",
    "            total_length (int): The maximum number of messages the chat history can hold.\n",
    "        \"\"\"\n",
    "        super().__init__(messages, total_length)\n",
    "\n",
    "    def append(self, msg: str):\n",
    "        \"\"\"Add a message to the queue. The first messaage will always stay fixed.\n",
    "\n",
    "        Args:\n",
    "            msg (str): The message to be added to the queue\n",
    "        \"\"\"\n",
    "        if len(self) == self.total_length:\n",
    "            self.pop(1)\n",
    "        super().append(msg)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class TagContentResult:\n",
    "    \"\"\"\n",
    "    A data class to represent the result of extracting tag content.\n",
    "\n",
    "    Attributes:\n",
    "        content (List[str]): A list of strings containing the content found between the specified tags.\n",
    "        found (bool): A flag indicating whether any content was found for the given tag.\n",
    "    \"\"\"\n",
    "\n",
    "    content: list[str]\n",
    "    found: bool\n",
    "\n",
    "\n",
    "def extract_tag_content(text: str, tag: str) -> TagContentResult:\n",
    "    \"\"\"\n",
    "    Extracts all content enclosed by specified tags (e.g., <thought>, <response>, etc.).\n",
    "\n",
    "    Parameters:\n",
    "        text (str): The input string containing multiple potential tags.\n",
    "        tag (str): The name of the tag to search for (e.g., 'thought', 'response').\n",
    "\n",
    "    Returns:\n",
    "        dict: A dictionary with the following keys:\n",
    "            - 'content' (list): A list of strings containing the content found between the specified tags.\n",
    "            - 'found' (bool): A flag indicating whether any content was found for the given tag.\n",
    "    \"\"\"\n",
    "    # Build the regex pattern dynamically to find multiple occurrences of the tag\n",
    "    tag_pattern = rf\"<{tag}>(.*?)</{tag}>\"\n",
    "\n",
    "    # Use findall to capture all content between the specified tag\n",
    "    matched_contents = re.findall(tag_pattern, text, re.DOTALL)\n",
    "\n",
    "    # Return the dataclass instance with the result\n",
    "    return TagContentResult(\n",
    "        content=[content.strip() for content in matched_contents],\n",
    "        found=bool(matched_contents),\n",
    "    )\n",
    "\n",
    "\n",
    "def fancy_print(message: str) -> None:\n",
    "    \"\"\"\n",
    "    Displays a fancy print message.\n",
    "\n",
    "    Args:\n",
    "        message (str): The message to display.\n",
    "    \"\"\"\n",
    "    print(Style.BRIGHT + Fore.CYAN + f\"\\n{'=' * 50}\")\n",
    "    print(Fore.MAGENTA + f\"{message}\")\n",
    "    print(Style.BRIGHT + Fore.CYAN + f\"{'=' * 50}\\n\")\n",
    "    time.sleep(0.5)\n",
    "\n",
    "\n",
    "def fancy_step_tracker(step: int, total_steps: int) -> None:\n",
    "    \"\"\"\n",
    "    Displays a fancy step tracker for each iteration of the generation-reflection loop.\n",
    "\n",
    "    Args:\n",
    "        step (int): The current step in the loop.\n",
    "        total_steps (int): The total number of steps in the loop.\n",
    "    \"\"\"\n",
    "    fancy_print(f\"STEP {step + 1}/{total_steps}\")\n",
    "\n",
    "import json\n",
    "from typing import Callable\n",
    "\n",
    "\n",
    "def get_fn_signature(fn: Callable) -> dict:\n",
    "    \"\"\"\n",
    "    Generates the signature for a given function.\n",
    "\n",
    "    Args:\n",
    "        fn (Callable): The function whose signature needs to be extracted.\n",
    "\n",
    "    Returns:\n",
    "        dict: A dictionary containing the function's name, description,\n",
    "              and parameter types.\n",
    "    \"\"\"\n",
    "    fn_signature: dict = {\n",
    "        \"name\": fn.__name__,\n",
    "        \"description\": fn.__doc__,\n",
    "        \"parameters\": {\"properties\": {}},\n",
    "    }\n",
    "    schema = {\n",
    "        k: {\"type\": v.__name__} for k, v in fn.__annotations__.items() if k != \"return\"\n",
    "    }\n",
    "    fn_signature[\"parameters\"][\"properties\"] = schema\n",
    "    return fn_signature\n",
    "\n",
    "\n",
    "def validate_arguments(tool_call: dict, tool_signature: dict) -> dict:\n",
    "    \"\"\"\n",
    "    Validates and converts arguments in the input dictionary to match the expected types.\n",
    "\n",
    "    Args:\n",
    "        tool_call (dict): A dictionary containing the arguments passed to the tool.\n",
    "        tool_signature (dict): The expected function signature and parameter types.\n",
    "\n",
    "    Returns:\n",
    "        dict: The tool call dictionary with the arguments converted to the correct types if necessary.\n",
    "    \"\"\"\n",
    "    properties = tool_signature[\"parameters\"][\"properties\"]\n",
    "\n",
    "    # TODO: This is overly simplified but enough for simple Tools.\n",
    "    type_mapping = {\n",
    "        \"int\": int,\n",
    "        \"str\": str,\n",
    "        \"bool\": bool,\n",
    "        \"float\": float,\n",
    "    }\n",
    "\n",
    "    for arg_name, arg_value in tool_call[\"arguments\"].items():\n",
    "        expected_type = properties[arg_name].get(\"type\")\n",
    "\n",
    "        if not isinstance(arg_value, type_mapping[expected_type]):\n",
    "            tool_call[\"arguments\"][arg_name] = type_mapping[expected_type](arg_value)\n",
    "\n",
    "    return tool_call\n",
    "\n",
    "\n",
    "class Tool:\n",
    "    \"\"\"\n",
    "    A class representing a tool that wraps a callable and its signature.\n",
    "\n",
    "    Attributes:\n",
    "        name (str): The name of the tool (function).\n",
    "        fn (Callable): The function that the tool represents.\n",
    "        fn_signature (str): JSON string representation of the function's signature.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, name: str, fn: Callable, fn_signature: str):\n",
    "        self.name = name\n",
    "        self.fn = fn\n",
    "        self.fn_signature = fn_signature\n",
    "\n",
    "    def __str__(self):\n",
    "        return self.fn_signature\n",
    "\n",
    "    def run(self, **kwargs):\n",
    "        \"\"\"\n",
    "        Executes the tool (function) with provided arguments.\n",
    "\n",
    "        Args:\n",
    "            **kwargs: Keyword arguments passed to the function.\n",
    "\n",
    "        Returns:\n",
    "            The result of the function call.\n",
    "        \"\"\"\n",
    "        return self.fn(**kwargs)\n",
    "\n",
    "\n",
    "def tool(fn: Callable):\n",
    "    \"\"\"\n",
    "    A decorator that wraps a function into a Tool object.\n",
    "\n",
    "    Args:\n",
    "        fn (Callable): The function to be wrapped.\n",
    "\n",
    "    Returns:\n",
    "        Tool: A Tool object containing the function, its name, and its signature.\n",
    "    \"\"\"\n",
    "\n",
    "    def wrapper():\n",
    "        fn_signature = get_fn_signature(fn)\n",
    "        return Tool(\n",
    "            name=fn_signature.get(\"name\"), fn=fn, fn_signature=json.dumps(fn_signature)\n",
    "        )\n",
    "\n",
    "    return wrapper()\n"
   ]
  },
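  {
   "cell_type": "markdown",
   "id": "c3a1e9f2-6b47-4d1a-8e25-0f7b9d3c5a61",
   "metadata": {},
   "source": [
    "`extract_tag_content` uses a non-greedy pattern (`.*?`) together with `re.DOTALL`, which is what lets it pull several tool calls out of a single reply even when each one spans several lines. A small self-contained check of that behaviour, on a reply text invented for illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a1e9f2-6b47-4d1a-8e25-0f7b9d3c5a62",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "\n",
    "# Made-up model reply containing two tool calls\n",
    "reply = '''<tool_call>\n",
    "{\"name\": \"get_current_weather\", \"arguments\": {\"location\": \"Madrid\", \"unit\": \"celsius\"}, \"id\": 0}\n",
    "</tool_call>\n",
    "<tool_call>\n",
    "{\"name\": \"get_current_weather\", \"arguments\": {\"location\": \"Paris\", \"unit\": \"celsius\"}, \"id\": 1}\n",
    "</tool_call>'''\n",
    "\n",
    "# Non-greedy: each match stops at the first closing tag after its opening tag\n",
    "non_greedy = re.findall(r'<tool_call>(.*?)</tool_call>', reply, re.DOTALL)\n",
    "# Greedy: one match runs from the first opening tag to the last closing tag\n",
    "greedy = re.findall(r'<tool_call>(.*)</tool_call>', reply, re.DOTALL)\n",
    "print(len(non_greedy), len(greedy))  # 2 1\n",
    "\n",
    "calls = [json.loads(chunk) for chunk in non_greedy]\n",
    "print([call['arguments']['location'] for call in calls])  # ['Madrid', 'Paris']"
   ]
  },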
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "727b6274",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import requests\n",
    "import json\n",
    "import re\n",
    "\n",
    "from colorama import Fore\n",
    "from dotenv import load_dotenv\n",
    "from groq import Groq\n",
    "\n",
    "\n",
    "\n",
    "load_dotenv()\n",
    "\n",
    "\n",
    "TOOL_SYSTEM_PROMPT = \"\"\"\n",
    "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.\n",
    "You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug\n",
    "into functions. Pay special attention to the properties 'types'. You should use those types as in a Python dict.\n",
    "For each function call return a json object with function name and arguments within <tool_call></tool_call>\n",
    "XML tags as follows:\n",
    "\n",
    "<tool_call>\n",
    "{\"name\": <function-name>,\"arguments\": <args-dict>,  \"id\": <monotonically-increasing-id>}\n",
    "</tool_call>\n",
    "\n",
    "Here are the available tools:\n",
    "\n",
    "<tools>\n",
    "%s\n",
    "</tools>\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "class ToolAgent:\n",
    "    \"\"\"\n",
    "    The ToolAgent class represents an agent that can interact with a language model and use tools\n",
    "    to assist with user queries. It generates function calls based on user input, validates arguments,\n",
    "    and runs the respective tools.\n",
    "\n",
    "    Attributes:\n",
    "        tools (Tool | list[Tool]): A list of tools available to the agent.\n",
    "        model (str): The model to be used for generating tool calls and responses.\n",
    "        client (Groq): The Groq client used to interact with the language model.\n",
    "        tools_dict (dict): A dictionary mapping tool names to their corresponding Tool objects.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        tools: Tool | list[Tool],\n",
    "        model: str = \"llama3-groq-70b-8192-tool-use-preview\",\n",
    "    ) -> None:\n",
    "        self.client = Groq()\n",
    "        self.model = model\n",
    "        self.tools = tools if isinstance(tools, list) else [tools]\n",
    "        self.tools_dict = {tool.name: tool for tool in self.tools}\n",
    "\n",
    "    def add_tool_signatures(self) -> str:\n",
    "        \"\"\"\n",
    "        Collects the function signatures of all available tools.\n",
    "\n",
    "        Returns:\n",
    "            str: A concatenated string of all tool function signatures in JSON format.\n",
    "        \"\"\"\n",
    "        return \"\".join([tool.fn_signature for tool in self.tools])\n",
    "\n",
    "    def process_tool_calls(self, tool_calls_content: list) -> dict:\n",
    "        \"\"\"\n",
    "        Processes each tool call, validates arguments, executes the tools, and collects results.\n",
    "\n",
    "        Args:\n",
    "            tool_calls_content (list): List of strings, each representing a tool call in JSON format.\n",
    "\n",
    "        Returns:\n",
    "            dict: A dictionary where the keys are tool call IDs and values are the results from the tools.\n",
    "        \"\"\"\n",
    "        observations = {}\n",
    "        for tool_call_str in tool_calls_content:\n",
    "            tool_call = json.loads(tool_call_str)\n",
    "            tool_name = tool_call[\"name\"]\n",
    "            tool = self.tools_dict[tool_name]\n",
    "\n",
    "            print(Fore.GREEN + f\"\\nUsing Tool: {tool_name}\")\n",
    "\n",
    "            # Validate and execute the tool call\n",
    "            validated_tool_call = validate_arguments(\n",
    "                tool_call, json.loads(tool.fn_signature)\n",
    "            )\n",
    "            print(Fore.GREEN + f\"\\nTool call dict: \\n{validated_tool_call}\")\n",
    "\n",
    "            result = tool.run(**validated_tool_call[\"arguments\"])\n",
    "            print(Fore.GREEN + f\"\\nTool result: \\n{result}\")\n",
    "\n",
    "            # Store the result using the tool call ID\n",
    "            observations[validated_tool_call[\"id\"]] = result\n",
    "\n",
    "        return observations\n",
    "\n",
    "    def run(\n",
    "        self,\n",
    "        user_msg: str,\n",
    "    ) -> str:\n",
    "        \"\"\"\n",
    "        Handles the full process of interacting with the language model and executing a tool based on user input.\n",
    "\n",
    "        Args:\n",
    "            user_msg (str): The user's message that prompts the tool agent to act.\n",
    "\n",
    "        Returns:\n",
    "            str: The final output after executing the tool and generating a response from the model.\n",
    "        \"\"\"\n",
    "        user_prompt = build_prompt_structure(prompt=user_msg, role=\"user\")\n",
    "\n",
    "        tool_chat_history = ChatHistory(\n",
    "            [\n",
    "                build_prompt_structure(\n",
    "                    prompt=TOOL_SYSTEM_PROMPT % self.add_tool_signatures(),\n",
    "                    role=\"system\",\n",
    "                ),\n",
    "                user_prompt,\n",
    "            ]\n",
    "        )\n",
    "        agent_chat_history = ChatHistory([user_prompt])\n",
    "\n",
    "        tool_call_response = completions_create(\n",
    "            self.client, messages=tool_chat_history, model=self.model\n",
    "        )\n",
    "        tool_calls = extract_tag_content(str(tool_call_response), \"tool_call\")\n",
    "\n",
    "        if tool_calls.found:\n",
    "            observations = self.process_tool_calls(tool_calls.content)\n",
    "            update_chat_history(\n",
    "                agent_chat_history, f\"Observation: {observations}\", \"user\"\n",
    "            )\n",
    "\n",
    "        return completions_create(self.client, agent_chat_history, self.model)\n"
   ]
  },
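  {
   "cell_type": "markdown",
   "id": "sketch-tool-call-dispatch",
   "metadata": {},
   "source": [
    "To make the control flow of `run` and `process_tool_calls` more concrete, here is a minimal, self-contained sketch that never calls Groq. It assumes, as the code above does, that the model wraps each call in `<tool_call>...</tool_call>` tags containing JSON with `name`, `arguments` and `id` keys. The names `extract_tool_calls`, `dispatch`, `add` and `fake_response` are hypothetical. The regex is only a stand-in for the notebook's `extract_tag_content`, and the `validate_arguments` step is skipped.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import re\n",
    "\n",
    "\n",
    "# Hypothetical stand-in for extract_tag_content: return every <tool_call>...</tool_call> body\n",
    "def extract_tool_calls(text):\n",
    "    return re.findall(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL)\n",
    "\n",
    "\n",
    "# Toy tool and registry, playing the role of ToolAgent.tools_dict (name -> callable)\n",
    "def add(a: int, b: int) -> int:\n",
    "    return a + b\n",
    "\n",
    "\n",
    "tools_dict = {'add': add}\n",
    "\n",
    "\n",
    "def dispatch(text):\n",
    "    # Like process_tool_calls: parse each call, look up the tool by name,\n",
    "    # run it with the parsed arguments and key the result by the call id\n",
    "    observations = {}\n",
    "    for call_str in extract_tool_calls(text):\n",
    "        call = json.loads(call_str)\n",
    "        fn = tools_dict[call['name']]\n",
    "        observations[call['id']] = fn(**call['arguments'])\n",
    "    return observations\n",
    "\n",
    "\n",
    "# A made-up model response, for illustration only\n",
    "fake_response = '<tool_call>{\"name\": \"add\", \"arguments\": {\"a\": 2, \"b\": 3}, \"id\": 0}</tool_call>'\n",
    "print(f'Observation: {dispatch(fake_response)}')  # Observation: {0: 5}\n",
    "\n",
    "# No tags means no tool calls, so run() would go straight to the final completion\n",
    "print(dispatch('My name is Llama.'))  # {}\n",
    "```\n",
    "\n",
    "The resulting dict is what `run` passes back to the model as an `Observation: ...` user message before asking for the final answer."
   ]
  },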
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9413902-e3ea-4c0a-bfd2-180d69ba5cd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import requests\n",
    "\n",
    "\n",
    "def fetch_top_hacker_news_stories(top_n: int):\n",
    "    \"\"\"\n",
    "    Fetch the top stories from Hacker News.\n",
    "\n",
    "    This function retrieves the top `top_n` stories using the official Hacker News API,\n",
    "    hosted on Firebase, which returns story details in JSON format. Only the title and\n",
    "    URL of each story are kept.\n",
    "\n",
    "    Args:\n",
    "        top_n (int): The number of top stories to retrieve.\n",
    "\n",
    "    Returns:\n",
    "        str: A JSON-encoded list of dicts with `title` and `url` keys.\n",
    "    \"\"\"\n",
    "    top_stories_url = 'https://hacker-news.firebaseio.com/v0/topstories.json'\n",
    "    \n",
    "    try:\n",
    "        response = requests.get(top_stories_url, timeout=10)\n",
    "        response.raise_for_status()  # Check for HTTP errors\n",
    "        \n",
    "        # Get the top story IDs\n",
    "        top_story_ids = response.json()[:top_n]\n",
    "        \n",
    "        top_stories = []\n",
    "        \n",
    "        # For each story ID, fetch the story details\n",
    "        for story_id in top_story_ids:\n",
    "            story_url = f'https://hacker-news.firebaseio.com/v0/item/{story_id}.json'\n",
    "            story_response = requests.get(story_url, timeout=10)\n",
    "            story_response.raise_for_status()  # Check for HTTP errors\n",
    "            story_data = story_response.json()\n",
    "            \n",
    "            # Append the story title and URL (or other relevant info) to the list\n",
    "            top_stories.append({\n",
    "                'title': story_data.get('title', 'No title'),\n",
    "                'url': story_data.get('url', 'No URL available'),\n",
    "            })\n",
    "        \n",
    "        return json.dumps(top_stories)\n",
    "\n",
    "    except requests.exceptions.RequestException as e:\n",
    "        print(f\"An error occurred: {e}\")\n",
    "        return json.dumps([])  # same type as the success path: a JSON string"
   ]
  },
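  {
   "cell_type": "markdown",
   "id": "sketch-hn-offline-parsing",
   "metadata": {},
   "source": [
    "The function above mixes two concerns: calling the API and reshaping its result. One way to check the reshaping offline is to pass the HTTP call in as a parameter. The sketch below assumes a hypothetical `get_json(url)` callable and uses a dict of fake responses (invented for illustration, not real Hacker News data). It keeps the same two-step flow: first the list of top story IDs from `topstories.json`, then one `item/<id>.json` request per story.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "BASE_URL = 'https://hacker-news.firebaseio.com/v0'\n",
    "\n",
    "\n",
    "# Hypothetical variant of fetch_top_hacker_news_stories with the HTTP call injected:\n",
    "# get_json(url) must return the parsed JSON body for that URL\n",
    "def top_stories(top_n, get_json):\n",
    "    story_ids = get_json(f'{BASE_URL}/topstories.json')[:top_n]\n",
    "    stories = []\n",
    "    for story_id in story_ids:\n",
    "        item = get_json(f'{BASE_URL}/item/{story_id}.json')\n",
    "        stories.append({\n",
    "            'title': item.get('title', 'No title'),\n",
    "            'url': item.get('url', 'No URL available'),\n",
    "        })\n",
    "    return json.dumps(stories)\n",
    "\n",
    "\n",
    "# Fake API responses, invented for this example\n",
    "fake_api = {\n",
    "    f'{BASE_URL}/topstories.json': [101, 102, 103],\n",
    "    f'{BASE_URL}/item/101.json': {'title': 'Story A', 'url': 'https://example.com/a'},\n",
    "    f'{BASE_URL}/item/102.json': {'title': 'Ask HN: Story B'},  # text post, no url\n",
    "}\n",
    "print(json.loads(top_stories(2, fake_api.get)))\n",
    "```\n",
    "\n",
    "Against the real API, `get_json` could be something like `lambda url: requests.get(url, timeout=10).json()`. Text posts (such as Ask HN) typically have no `url` field, which is the case the `'No URL available'` fallback covers."
   ]
  },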
  {
   "cell_type": "markdown",
   "id": "73f75359-1e8a-4317-92dd-40dd1cf36e97",
   "metadata": {},
   "source": [
    "If we run this Python function, we'll obtain the top HN stories, as you can see below (the top 5 in this case)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aad2bbed-549e-4c0e-91fd-37b4694e0b50",
   "metadata": {},
   "outputs": [],
   "source": [
    "json.loads(fetch_top_hacker_news_stories(top_n=5))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb587d13-b312-45b5-af56-4f009c11eeda",
   "metadata": {},
   "source": [
    "To transform the `fetch_top_hacker_news_stories` function into a Tool, we can use the `tool` decorator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4616e412-d4a8-4fe5-bcb1-dd00ce48640a",
   "metadata": {},
   "outputs": [],
   "source": [
    "hn_tool = tool(fetch_top_hacker_news_stories)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f438638-a933-414f-9d00-53c37f041f16",
   "metadata": {},
   "source": [
    "The Tool has the following attributes: a `name`, a `fn_signature` and the `fn` (the function we are going to call, in this case `fetch_top_hacker_news_stories`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df16bfa5-0ed4-46e1-b262-006f36fb8e78",
   "metadata": {},
   "outputs": [],
   "source": [
    "hn_tool.name"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3209e3e0-b59c-4b0e-b075-8fcbf9d21516",
   "metadata": {},
   "source": [
    "By default, the tool gets its name from the function name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0da95e0-10a8-4d17-aae7-ed3cc20abb03",
   "metadata": {},
   "outputs": [],
   "source": [
    "json.loads(hn_tool.fn_signature)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5760bf7-7d9a-4c79-bc87-6469040250b6",
   "metadata": {},
   "source": [
    "As you can see, the function signature has been generated automatically. It contains the `name`, a `description` (taken from the docstring) and the `parameters`, whose types come from the type annotations. Now that we have a tool, let's run the agent."
   ]
  },
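  {
   "cell_type": "markdown",
   "id": "sketch-fn-signature",
   "metadata": {},
   "source": [
    "Independently of how `tool` is implemented, the idea behind the automatic signature can be sketched in a few lines with the standard `inspect` module. The helper below, `sketch_fn_signature`, is hypothetical and not the notebook's own code. It reads the function name, the docstring and the parameter annotations, maps a few built-in types to their names, and falls back to `'str'` for anything else (an arbitrary choice made only for this sketch).\n",
    "\n",
    "```python\n",
    "import inspect\n",
    "import json\n",
    "\n",
    "\n",
    "# Hypothetical helper: build a JSON signature from a function's name, docstring and type hints\n",
    "def sketch_fn_signature(fn):\n",
    "    type_names = {int: 'int', float: 'float', str: 'str', bool: 'bool'}\n",
    "    properties = {}\n",
    "    for name, param in inspect.signature(fn).parameters.items():\n",
    "        # Unannotated or unknown types fall back to 'str' in this sketch\n",
    "        properties[name] = {'type': type_names.get(param.annotation, 'str')}\n",
    "    return json.dumps({\n",
    "        'name': fn.__name__,\n",
    "        'description': inspect.getdoc(fn),\n",
    "        'parameters': {'properties': properties},\n",
    "    })\n",
    "\n",
    "\n",
    "def fetch_top_stories(top_n: int):\n",
    "    '''Fetch the top stories from Hacker News.'''\n",
    "    return []\n",
    "\n",
    "\n",
    "print(json.loads(sketch_fn_signature(fetch_top_stories)))\n",
    "```\n",
    "\n",
    "Because the description comes straight from the docstring, the docstring doubles as the text the LLM reads when deciding whether to call the tool."
   ]
  },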
  {
   "cell_type": "markdown",
   "id": "043ad8ba-7789-468a-aafd-60c10bd21135",
   "metadata": {},
   "source": [
    "### The `ToolAgent`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "065e04e5-50af-4452-9086-eae08a12e8cf",
   "metadata": {},
   "source": [
    "To create the agent, we just need to pass a list of tools (in this case, just one)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a303211-f2a6-43c0-85aa-081fb0be2bbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_agent = ToolAgent(tools=[hn_tool])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9eabae6-2b5c-407e-9e43-88e5e4844e9e",
   "metadata": {},
   "source": [
    "First, a quick sanity check: if we ask the agent something unrelated to Hacker News, it shouldn't use the tool."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92c706fd-0a4b-46be-bbb3-c02618dbf677",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = tool_agent.run(user_msg=\"Tell me your name\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be02a976-1e72-40ad-9ada-460148ca65d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c862a34-3cc9-428b-a246-d98effc998a5",
   "metadata": {},
   "source": [
    "Now, let's ask for specific information about Hacker News."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b74bbc64-8943-4ae3-9928-6230ead61e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = tool_agent.run(user_msg=\"Tell me the top 5 Hacker News stories right now\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53476bff-812d-4e56-afb9-de21474f6580",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f70a8059-9637-45b5-8050-ea7ba4995407",
   "metadata": {},
   "source": [
    "---\n",
    "There you have it!! A fully functional Tool!! 🛠️"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
